"""
Embeddings
==========

The ``Embeddings`` and ``EmbeddingsView`` classes pair the 2-dimensional array of
embeddings generated for a graph with the node label associated with each embedding.

The generated embeddings follow the natural ordering of the provided networkx Graph
object. Because they are correlated with the nodes only by position, and the source
graph's natural ordering may change as the graph is mutated, it is often helpful to
keep a snapshot of the labels taken at the time the embedding was generated alongside
the embedding. This guarantees that the appropriate latent position is returned for
each node.
"""

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import OrderedDict
from typing import Any

import numpy as np
from beartype import beartype

from graspologic.types import Tuple


class Embeddings:
    """
    Row-aligned pairing of node labels with their embedding vectors.

    Row ``i`` of the labels array is the label for row ``i`` of the embeddings
    array. Supports ``len()``, integer indexing, iteration, and conversion to
    an ordered label -> embedding mapping via ``as_dict()``.
    """

    @beartype
    def __init__(self, labels: np.ndarray, embeddings: np.ndarray):
        """
        ``Embeddings`` is an iterable, indexed interface over the parallel numpy arrays
        that are generated by our embedding functions.

        >>> import numpy as np
        >>> raw_embeddings = np.array([[1,2,3,4,5], [6,4,2,0,8]])
        >>> labels = np.array(["monotonic", "not_monotonic"])
        >>> embeddings_obj = Embeddings(labels, raw_embeddings)
        >>> embeddings_obj[0]
        ('monotonic', array([1, 2, 3, 4, 5]))
        >>> len(embeddings_obj)
        2
        >>> list(iter(embeddings_obj))
        [('monotonic', array([1, 2, 3, 4, 5])), ('not_monotonic', array([6, 4, 2, 0, 8]))]
        >>> embeddings_lookup = embeddings_obj.as_dict()
        >>> embeddings_lookup["not_monotonic"]
        array([6, 4, 2, 0, 8])

        Parameters
        ----------
        labels : np.ndarray
            The node labels that are positionally correlated with the embeddings.
            The dtype of labels is any object stored in a networkx Graph object,
            though type uniformity will be required
        embeddings : np.ndarray
            The embedded values generated by the embedding technique.

        Raises
        ------
        beartype.roar.BeartypeCallHintParamViolation if the types are invalid
        ValueError if the row count of labels does not equal the row count of embeddings
        """
        if labels.shape[0] != embeddings.shape[0]:
            raise ValueError(
                "labels and embeddings must have the same number of "
                f"rows, labels: {labels.shape[0]}, "
                f"embeddings: {embeddings.shape[0]}"
            )
        # Both arrays are stored as given (no copy).
        self._labels = labels
        self._embeddings = embeddings

    def labels(self) -> np.ndarray:
        """Return the labels array, row-aligned with ``embeddings()``."""
        return self._labels

    def embeddings(self) -> np.ndarray:
        """Return the embeddings array, row-aligned with ``labels()``."""
        return self._embeddings

    def as_dict(self) -> "EmbeddingsView":
        """Return an ``EmbeddingsView`` mapping each label to its embedding row."""
        return EmbeddingsView(self)

    def __len__(self) -> int:
        """Return the number of rows (labels/embeddings pairs)."""
        return self._labels.shape[0]

    def __getitem__(self, index: int) -> Tuple[Any, np.ndarray]:
        """
        Return the ``(label, embedding)`` pair stored at row ``index``.

        Negative indices count from the end, as with Python sequences.

        Raises
        ------
        IndexError if ``index`` lies outside ``[-len(self), len(self))``
        """
        size = len(self)
        # Validate both bounds so every out-of-range index (including negative
        # ones numpy would otherwise reject with its own wording) gets the same
        # error message, while valid negative indices keep working.
        if not -size <= index < size:
            raise IndexError(
                f"index '{index}' is out of bounds of "
                f"embeddings length '{size}'"
            )
        return self._labels[index], self._embeddings[index]

    def __iter__(self) -> "_EmbeddingsIter":
        """Return a fresh iterator over ``(label, embedding)`` pairs in row order."""
        return _EmbeddingsIter(self)


class _EmbeddingsIter:
    """
    Single-pass iterator over an ``Embeddings`` instance.

    Yields ``(label, embedding)`` pairs in row order; the cursor is never
    reset, so a new iterator is needed for each traversal.
    """

    @beartype
    def __init__(self, embeddings: Embeddings):
        # Source container plus a cursor positioned at the first row.
        self._embeddings = embeddings
        self._index = 0

    def __iter__(self) -> "_EmbeddingsIter":
        # An iterator is its own iterable.
        return self

    def __next__(self) -> Tuple[Any, np.ndarray]:
        cursor = self._index
        if cursor < len(self._embeddings):
            # Fetch before advancing so a failing lookup leaves the cursor put.
            pair = self._embeddings[cursor]
            self._index = cursor + 1
            return pair
        raise StopIteration


class EmbeddingsView(OrderedDict):
    """
    Ordered ``label -> embedding`` mapping built from an ``Embeddings`` object.

    Entries are inserted in row order. If a label appears more than once, the
    later row's embedding replaces the earlier value while the key keeps its
    original position (standard ``OrderedDict`` assignment semantics).
    """

    @beartype
    def __init__(self, embeddings: Embeddings):
        # ``@beartype`` normally rejects bad arguments first; these explicit
        # checks remain as a fallback and keep their original exception types
        # (None -> ValueError, other wrong types -> TypeError).
        if embeddings is None:
            raise ValueError("embeddings must not be None")
        if not isinstance(embeddings, Embeddings):
            raise TypeError(
                "embeddings must be a graspologic.pipeline.embed.Embeddings"
            )
        super().__init__()
        # ``Embeddings`` has no ``keys`` attribute, so ``update`` consumes it as
        # an iterable of (label, embedding) pairs and assigns each in turn.
        self.update(embeddings)
